Support for WMMA instructions for RDNA4 GPUs #929

Open: ffrancesco94 wants to merge 13 commits into JuliaGPU:main from ffrancesco94:feat/wmma-rdna4

@ffrancesco94

As per title. Most things work the same way as on RDNA3, except that no data duplication is needed: on RDNA4 each lane holds 8 elements rather than 16 (8*2, as before). I kept both the implementation and the tests separate from the RDNA3 version, but in the future they could perhaps be merged behind a runtime dispatch based on the hardware.

DISCLAIMER: I used Mistral Vibe to get the first draft and took development from there. I do not know whether this project has a specific policy against the use of AI in pull requests, so I understand if you're not willing to look at it. The tests pass on my RX 9070 XT, and if you diff against the RDNA3 version you'll see that the logic is the same; what changes is the shape of the fragments and the addressing within the lanes.

ffrancesco94 and others added 8 commits June 5, 2026 22:42
This commit adds Wave Matrix Multiply Accumulate (WMMA) instruction support
for AMD's RDNA4 architecture GPUs (gfx1200+).

Changes:
- Add WMMA_RDNA4 module in src/device/gcn/wmma_rdna4.jl
- Support for new RDNA4 WMMA intrinsics with _gfx12 suffix
- Simplified VGPR layout (no data duplication, 8 elements per thread)
- Support for Float16 and BFloat16 types (FP8 types ready for future addition)
- Add comprehensive tests in test/wmma_rdna4_tests.jl
- Update documentation with RDNA4 section and examples
- Update existing WMMA tests to also detect RDNA4

Architectural Differences from RDNA3:
- Each lane handles 8 elements (vs 16 with duplication in RDNA3)
- New intrinsic names with _gfx12 suffix and explicit vector type annotations
- Subtarget feature: wmma-128b-insts (vs gfx11-insts for RDNA3)
- Cleaner VGPR distribution with no data duplication
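
The per-lane counts above can be checked with a little arithmetic. The sketch below assumes wave32 execution and a 16x16 A or B tile, neither of which is spelled out in this commit message; `tile_copies` is a hypothetical helper written for illustration and is not part of AMDGPU.jl.

```julia
# Assumptions (not stated in the PR): wave32 execution and a 16x16 A or B tile.
# `tile_copies` is a hypothetical helper, not part of AMDGPU.jl.
function tile_copies(per_lane::Int; wave_size::Int = 32, tile_elems::Int = 16 * 16)
    # Total elements held across the wave, divided by the tile size,
    # gives how many times the tile is stored in registers.
    return (per_lane * wave_size) ÷ tile_elems
end

rdna3 = tile_copies(16)  # 16 elements per lane, as described for RDNA3
rdna4 = tile_copies(8)   # 8 elements per lane, as described for RDNA4
println("RDNA3 stores the tile ", rdna3, " time(s); RDNA4 stores it ", rdna4, " time(s)")
```

Under these assumptions the RDNA3 layout holds every tile element twice across the wave, while the RDNA4 layout holds it once, which is consistent with the "no data duplication" point above.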

References:
- AMD GPUOpen: https://gpuopen.com/learn/using_matrix_core_amd_rdna4/
- LLVM commit: llvm/llvm-project@829afc4

Generated by Mistral Vibe.
Co-Authored-By: Mistral Vibe <vibe@mistral.ai>
Added tile pointer and stride helper functions for WMMA_RDNA4.
Updated example code block to use Julia syntax highlighting.
@ffrancesco94

Author

My last commits fail CI, but the logs don't actually show any errors. I don't know whether this is related to the Buildkite updates.

@github-actions github-actions Bot left a comment

Contributor

AMDGPU.jl Benchmarks

| Benchmark suite | Current: 9ee1ce0 | Previous: 756602c | Ratio |
|---|---|---|---|
| amdgpu/synchronization/context/device | 600 ns | 600 ns | 1 |
| amdgpu/synchronization/stream/blocking | 260 ns | 240 ns | 1.08 |
| amdgpu/synchronization/stream/nonblocking | 340 ns | 340 ns | 1 |
| array/accumulate/Float32/1d | 84941 ns | 86251 ns | 0.98 |
| array/accumulate/Float32/dims=1 | 383596 ns | 393845 ns | 0.97 |
| array/accumulate/Float32/dims=1L | 134982 ns | 131681 ns | 1.03 |
| array/accumulate/Float32/dims=2 | 130392 ns | 103022 ns | 1.27 |
| array/accumulate/Float32/dims=2L | 2809690 ns | 2827930 ns | 0.99 |
| array/accumulate/Int64/1d | 98541 ns | 96412 ns | 1.02 |
| array/accumulate/Int64/dims=1 | 288534 ns | 285244 ns | 1.01 |
| array/accumulate/Int64/dims=1L | 167452 ns | 160812 ns | 1.04 |
| array/accumulate/Int64/dims=2 | 123992 ns | 120772 ns | 1.03 |
| array/accumulate/Int64/dims=2L | 2983033 ns | 3014433 ns | 0.99 |
| array/broadcast | 133412 ns | 128932 ns | 1.03 |
| array/construct | 1680 ns | 1680 ns | 1 |
| array/copy | 38771 ns | 39371 ns | 0.98 |
| array/copyto!/cpu_to_gpu | 183313 ns | 114832 ns | 1.60 |
| array/copyto!/gpu_to_cpu | 183493 ns | 152432 ns | 1.20 |
| array/copyto!/gpu_to_gpu | 128142 ns | 88321 ns | 1.45 |
| array/iteration/findall/bool | 179452 ns | 181912 ns | 0.99 |
| array/iteration/findall/int | 189393 ns | 190933 ns | 0.99 |
| array/iteration/findfirst/bool | 124201 ns | 114451 ns | 1.09 |
| array/iteration/findfirst/int | 114811 ns | 116331 ns | 0.99 |
| array/iteration/findmin/1d | 169482 ns | 166203 ns | 1.02 |
| array/iteration/findmin/2d | 155612 ns | 156173 ns | 1.00 |
| array/iteration/logical | 353385 ns | 346025 ns | 1.02 |
| array/iteration/scalar | 295995 ns | 289864 ns | 1.02 |
| array/permutedims/2d | 74041 ns | 64761 ns | 1.14 |
| array/permutedims/3d | 74331 ns | 73791 ns | 1.01 |
| array/permutedims/4d | 77391 ns | 76481 ns | 1.01 |
| array/random/rand/Float32 | 54581 ns | 51540 ns | 1.06 |
| array/random/rand/Int64 | 57501 ns | 56210 ns | 1.02 |
| array/random/rand!/Float32 | 146032 ns | 142162 ns | 1.03 |
| array/random/rand!/Int64 | 147413 ns | 141832 ns | 1.04 |
| array/random/randn/Float32 | 99331 ns | 86921 ns | 1.14 |
| array/random/randn!/Float32 | 87222 ns | 152202 ns | 0.57 |
| array/reductions/mapreduce/Float32/1d | 130992 ns | 132902 ns | 0.99 |
| array/reductions/mapreduce/Float32/dims=1 | 93392 ns | 95052 ns | 0.98 |
| array/reductions/mapreduce/Float32/dims=1L | 774481 ns | 777081 ns | 1.00 |
| array/reductions/mapreduce/Float32/dims=2 | 97692 ns | 96731 ns | 1.01 |
| array/reductions/mapreduce/Float32/dims=2L | 297145 ns | 299584 ns | 0.99 |
| array/reductions/mapreduce/Int64/1d | 134432 ns | 133322 ns | 1.01 |
| array/reductions/mapreduce/Int64/dims=1 | 95691 ns | 78081 ns | 1.23 |
| array/reductions/mapreduce/Int64/dims=1L | 782341 ns | 783471 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2 | 96462 ns | 96252 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2L | 303244 ns | 308254 ns | 0.98 |
| array/reductions/reduce/Float32/1d | 133771 ns | 132802 ns | 1.01 |
| array/reductions/reduce/Float32/dims=1 | 95121 ns | 94832 ns | 1.00 |
| array/reductions/reduce/Float32/dims=1L | 773901 ns | 774621 ns | 1.00 |
| array/reductions/reduce/Float32/dims=2 | 97152 ns | 96802 ns | 1.00 |
| array/reductions/reduce/Float32/dims=2L | 296735 ns | 307245 ns | 0.97 |
| array/reductions/reduce/Int64/1d | 134502 ns | 129672 ns | 1.04 |
| array/reductions/reduce/Int64/dims=1 | 95331 ns | 78151 ns | 1.22 |
| array/reductions/reduce/Int64/dims=1L | 782231 ns | 781931 ns | 1.00 |
| array/reductions/reduce/Int64/dims=2 | 96621 ns | 96192 ns | 1.00 |
| array/reductions/reduce/Int64/dims=2L | 296015 ns | 298414 ns | 0.99 |
| array/reverse/1d | 44641 ns | 44380 ns | 1.01 |
| array/reverse/1dL | 75261 ns | 74131 ns | 1.02 |
| array/reverse/1dL_inplace | 127212 ns | 108282 ns | 1.17 |
| array/reverse/1d_inplace | 78901 ns | 86471 ns | 0.91 |
| array/reverse/2d | 51100 ns | 50661 ns | 1.01 |
| array/reverse/2dL | 101772 ns | 100341 ns | 1.01 |
| array/reverse/2dL_inplace | 135532 ns | 117622 ns | 1.15 |
| array/reverse/2d_inplace | 79061 ns | 95391 ns | 0.83 |
| array/sorting/1d | 340625 ns | 341945 ns | 1.00 |
| integration/byval/reference | 39621 ns | 38830 ns | 1.02 |
| integration/byval/slices=1 | 40880 ns | 40880 ns | 1 |
| integration/byval/slices=2 | 145702 ns | 158462 ns | 0.92 |
| integration/byval/slices=3 | 237713 ns | 238013 ns | 1.00 |
| integration/volumerhs | 5034992 ns | 4942659 ns | 1.02 |
| kernel/indexing | 104961 ns | 43630 ns | 2.41 |
| kernel/indexing_checked | 131991 ns | 128022 ns | 1.03 |
| kernel/launch | 1340 ns | 1290 ns | 1.04 |
| kernel/rand | 203133 ns | 106671 ns | 1.90 |
| latency/import | 1486594790 ns | 1501349912 ns | 0.99 |
| latency/precompile | 11980399330 ns | 12041117438 ns | 0.99 |
| latency/ttfp | 10904391209 ns | 10491950084 ns | 1.04 |

This comment was automatically generated by workflow using github-action-benchmark.

@luraess

luraess commented Jun 15, 2026

Member

Yes, we had issues with the runners. I relaunched the failed jobs. Is the current state ready for review, or are you still working on it?

@ffrancesco94

Author

If it passes the CI I'm happy to get it reviewed!

@luraess

luraess commented Jun 17, 2026

Member

The addition looks good to me. I was curious whether you had a specific idea in mind for how to achieve this:

> but maybe in the future they could be kind of merged with a runtime dispatch based on the hardware.

@ffrancesco94

Author

I haven't given it much thought, but something along the lines of a trait: empty structs RDNA4 and RDNA3 as subtypes of an abstract type RDNAArch, with mma redefined to also take this struct as an argument so that it can dispatch to the right intrinsics. What we actually export would then be an mma method that checks the architecture and calls the appropriate version. I don't know if it makes sense. I haven't really tried it, partly because I only have an RDNA4 GPU and can't test this dispatch myself (the Buildkite runner only seems to have an RDNA3 GPU, from what I can see), and partly because this is my first contribution and I didn't want to mess too much with the existing codebase. When I get more comfortable I'll try to get into that :)
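
For illustration only, the trait idea described above could look roughly like the following host-side sketch. Every name in it (`RDNAArch`, `RDNA3`, `RDNA4`, `arch_trait`, `elements_per_lane`, and this `mma` signature) is hypothetical and not taken from AMDGPU.jl. The `mma` bodies are plain arithmetic stand-ins rather than WMMA intrinsics, and the mapping from a `gfx11`/`gfx12` target prefix to an architecture is an assumption.

```julia
# Hypothetical trait types; none of these names exist in AMDGPU.jl.
abstract type RDNAArch end
struct RDNA3 <: RDNAArch end
struct RDNA4 <: RDNAArch end

# Per-lane element counts as described in this PR.
elements_per_lane(::RDNA3) = 16  # includes duplicated data
elements_per_lane(::RDNA4) = 8   # no duplication

# Assumed mapping from a target name such as "gfx1201" to a trait instance.
function arch_trait(gfx::AbstractString)
    startswith(gfx, "gfx12") && return RDNA4()
    startswith(gfx, "gfx11") && return RDNA3()
    throw(ArgumentError("no WMMA trait assumed for target $gfx"))
end

# Stand-ins for the architecture-specific implementations.
# The real versions would call the RDNA3 or RDNA4 intrinsics.
mma(::RDNA3, a, b, c) = (:rdna3, a * b + c)
mma(::RDNA4, a, b, c) = (:rdna4, a * b + c)

# Single entry point: look up the trait, then dispatch on it.
mma(gfx::AbstractString, a, b, c) = mma(arch_trait(gfx), a, b, c)

println(mma("gfx1201", 2.0, 3.0, 1.0))
```

In a real kernel the architecture would have to be known at compile time rather than checked through a string at run time, so this only shows the shape of the dispatch, not how it would be wired into device code.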

@luraess

luraess commented Jun 19, 2026

Member

Thanks. Feel free to open an issue recording the thoughts above, so that we have a trace for future to-dos.

One last thing: for consistency, would it make sense to explicitly have WMMA_RDNA3 (by analogy with the new WMMA_RDNA4 and possible future ones) instead of WMMA, which stands for RDNA3 only?

@ffrancesco94

Author

It makes sense and I have changed that. I just hope it doesn't break anything downstream for people who were referring to WMMA directly. Buildkite is running now.
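
One possible way to soften the downstream breakage mentioned above is to keep the old name as an alias of the renamed module. The sketch below uses a toy stand-in module to show the mechanism. Whether this PR does anything like it is not stated, and the `mma` body here is a placeholder, not the real implementation.

```julia
# Toy stand-in for the renamed module; the real one wraps the RDNA3 intrinsics.
module WMMA_RDNA3
    mma(a, b, c) = a * b + c
end

# Hypothetical compatibility alias: code that still refers to `WMMA`
# resolves to the renamed module.
const WMMA = WMMA_RDNA3

println(WMMA.mma(2, 3, 1) == WMMA_RDNA3.mma(2, 3, 1))
```

A plain alias keeps old code running silently; a deprecation warning would be a separate design decision.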

Comment thread test/wmma_tests.jl Outdated
@luraess

luraess commented Jun 22, 2026

Member

@pxl-th, is it fine with you to rename WMMA -> WMMA_3, for consistency with 4?
